{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Inference in Discrete Bayesian Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we show a simple example of exact inference in Bayesian Networks using pgmpy. We will use the Asia network (http://www.bnlearn.com/bnrepository/#asia) for this example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1: Define the model." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Fetch the asia model from the bnlearn repository\n", "\n", "from pgmpy.utils import get_example_model\n", "\n", "asia_model = get_example_model(\"asia\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Nodes: ['asia', 'tub', 'smoke', 'lung', 'bronc', 'either', 'xray', 'dysp']\n", "Edges: [('asia', 'tub'), ('tub', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('lung', 'either'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp')]\n" ] } ], "source": [ "print(\"Nodes: \", asia_model.nodes())\n", "print(\"Edges: \", asia_model.edges())\n", "asia_model.get_cpds()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you would like to create a model from scratch, please refer to the Creating Bayesian Networks notebook: https://github.com/pgmpy/pgmpy/blob/dev/examples/Creating%20a%20Bayesian%20Network.ipynb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 2: Initialize the inference class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Currently, pgmpy supports two algorithms for inference: 1. Variable Elimination and 2. Belief Propagation. Both of these are exact inference algorithms. 
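Because both algorithms are exact, they return the same posterior probabilities and differ only in how they organize the computation. As a rough, library-free sketch of what exact inference computes, the snippet below enumerates every full assignment of the Asia network, multiplies the CPD entries (the chain rule for Bayesian networks), and normalizes over the assignments consistent with the evidence. The CPD values for `asia`, `lung`, and `bronc` match the outputs in this notebook; the remaining values (including P(smoke = yes) = 0.5) are assumed from the bnlearn Asia network, and the `query` helper is hypothetical, not part of pgmpy.

```python
from itertools import product

STATES = ('yes', 'no')

# Asia network CPDs as {variable: (parents, {parent states: P(variable = 'yes')})}.
# Values follow the bnlearn Asia network; those not printed in this notebook are assumed.
CPDS = {
    'asia': ((), {(): 0.01}),
    'smoke': ((), {(): 0.5}),
    'tub': (('asia',), {('yes',): 0.05, ('no',): 0.01}),
    'lung': (('smoke',), {('yes',): 0.1, ('no',): 0.01}),
    'bronc': (('smoke',), {('yes',): 0.6, ('no',): 0.3}),
    'either': (('lung', 'tub'), {('yes', 'yes'): 1.0, ('yes', 'no'): 1.0,
                                 ('no', 'yes'): 1.0, ('no', 'no'): 0.0}),
    'xray': (('either',), {('yes',): 0.98, ('no',): 0.05}),
    'dysp': (('bronc', 'either'), {('yes', 'yes'): 0.9, ('yes', 'no'): 0.8,
                                   ('no', 'yes'): 0.7, ('no', 'no'): 0.1}),
}
VARIABLES = list(CPDS)


def joint_prob(assignment):
    # Chain rule: the product of P(var | parents) over all variables.
    p = 1.0
    for var, (parents, table) in CPDS.items():
        p_yes = table[tuple(assignment[parent] for parent in parents)]
        p *= p_yes if assignment[var] == 'yes' else 1.0 - p_yes
    return p


def query(variables, evidence):
    # Hypothetical helper (not pgmpy's API): exact inference by brute-force enumeration.
    # Sums the joint over all full assignments consistent with the evidence,
    # grouped by the states of the query variables, then normalizes.
    scores = {}
    for values in product(STATES, repeat=len(VARIABLES)):
        assignment = dict(zip(VARIABLES, values))
        if any(assignment[var] != state for var, state in evidence.items()):
            continue
        key = tuple(assignment[var] for var in variables)
        scores[key] = scores.get(key, 0.0) + joint_prob(assignment)
    total = sum(scores.values())
    return {key: p / total for key, p in scores.items()}


print(query(['bronc'], {'smoke': 'no'}))
print(query(['bronc', 'asia'], {'smoke': 'yes'}))
```

The two calls should agree, up to floating-point rounding, with the `VariableElimination` results printed in Step 3 (for example, P(bronc = yes | smoke = no) = 0.3). Enumeration costs 2^8 = 256 terms here and grows exponentially with the number of variables, which is the cost that Variable Elimination and Belief Propagation reduce by exploiting the factorized structure of the network.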
The following example uses `VariableElimination`, but `BeliefPropagation` has an identical API, so all the methods shown below also work with `BeliefPropagation`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Initializing the VariableElimination class\n", "\n", "from pgmpy.inference import VariableElimination\n", "\n", "asia_infer = VariableElimination(asia_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 3: Inference using hard evidence" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+--------------+\n", "| bronc | phi(bronc) |\n", "+============+==============+\n", "| bronc(yes) | 0.3000 |\n", "+------------+--------------+\n", "| bronc(no) | 0.7000 |\n", "+------------+--------------+\n", "+------------+-----------+-------------------+\n", "| bronc | asia | phi(bronc,asia) |\n", "+============+===========+===================+\n", "| bronc(yes) | asia(yes) | 0.0060 |\n", "+------------+-----------+-------------------+\n", "| bronc(yes) | asia(no) | 0.5940 |\n", "+------------+-----------+-------------------+\n", "| bronc(no) | asia(yes) | 0.0040 |\n", "+------------+-----------+-------------------+\n", "| bronc(no) | asia(no) | 0.3960 |\n", "+------------+-----------+-------------------+\n", "+------------+--------------+\n", "| bronc | phi(bronc) |\n", "+============+==============+\n", "| bronc(yes) | 
0.3000 |\n", "+------------+--------------+\n", "| bronc(no) | 0.7000 |\n", "+------------+--------------+\n", "+-----------+-------------+\n", "| asia | phi(asia) |\n", "+===========+=============+\n", "| asia(yes) | 0.0100 |\n", "+-----------+-------------+\n", "| asia(no) | 0.9900 |\n", "+-----------+-------------+\n" ] } ], "source": [ "# Computing the probability of bronc given smoke=no\n", "q = asia_infer.query(variables=[\"bronc\"], evidence={\"smoke\": \"no\"})\n", "print(q)\n", "\n", "# Computing the joint probability of bronc and asia given smoke=yes\n", "q = asia_infer.query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"yes\"})\n", "print(q)\n", "\n", "# Computing the marginal probabilities (not joint) of bronc and asia given smoke=no;\n", "# with joint=False, query returns a dict mapping each variable to its factor\n", "q = asia_infer.query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"no\"}, joint=False)\n", "for factor in q.values():\n", "    print(factor)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'bronc': 'no'}\n", "{'bronc': 'yes', 'asia': 'no'}\n" ] } ], "source": [ "# Computing the MAP (most probable state) of bronc given smoke=no\n", "q = asia_infer.map_query(variables=[\"bronc\"], evidence={\"smoke\": \"no\"})\n", "print(q)\n", "\n", "# Computing the MAP of bronc and 
asia given smoke=yes\n", "q = asia_infer.map_query(variables=[\"bronc\", \"asia\"], evidence={\"smoke\": \"yes\"})\n", "print(q)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 4: Inference using virtual evidence\n", "\n", "Virtual (soft) evidence specifies a likelihood over the states of a variable instead of fixing the variable to a single observed state. In pgmpy, it is passed to `query` through the `virtual_evidence` argument as a list of single-variable `TabularCPD` objects." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+--------------+\n", "| bronc | phi(bronc) |\n", "+============+==============+\n", "| bronc(yes) | 0.3000 |\n", "+------------+--------------+\n", "| bronc(no) | 0.7000 |\n", "+------------+--------------+\n" ] } ], "source": [ "from pgmpy.factors.discrete import TabularCPD\n", "\n", "lung_virt_evidence = TabularCPD(variable=\"lung\", variable_card=2, values=[[0.4], [0.6]])\n", "\n", "# Query with hard evidence smoke = no and virtual evidence lung = [0.4, 0.6]\n", "q = asia_infer.query(\n", "    variables=[\"bronc\"], evidence={\"smoke\": \"no\"}, virtual_evidence=[lung_virt_evidence]\n", ")\n", "print(q)\n", "\n", "# Query with hard evidence smoke = no and virtual evidences lung = [0.4, 0.6] and bronc = [0.3, 0.7]\n", "bronc_virt_evidence = TabularCPD(variable=\"bronc\", variable_card=2, values=[[0.3], [0.7]])\n", "q = asia_infer.query(\n", "    variables=[\"bronc\"],\n", "    evidence={\"smoke\": \"no\"},\n", "    virtual_evidence=[lung_virt_evidence, bronc_virt_evidence],\n", ")\n", "print(q)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----------+------------+-----------+\n", "| smoke | smoke(yes) | smoke(no) |\n", "+-----------+------------+-----------+\n", "| lung(yes) | 0.1 | 0.01 |\n", "+-----------+------------+-----------+\n", "| lung(no) | 0.9 | 0.99 |\n", "+-----------+------------+-----------+\n" ] } ], "source": [ "# Print the CPD P(lung | smoke) for comparison with the virtual evidence on lung used above\n", "print(asia_model.get_cpds(\"lung\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 5: Troubleshooting slow inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large models, or models whose variables have many states, inference can be quite slow. Some ways to deal with this are:\n", "\n", "1. Reduce the number of states of variables by combining states together.\n", "2. Try a different elimination-order heuristic by specifying the `elimination_order` argument. The possible options are: `MinFill`, `MinNeighbors`, `MinWeight`, and `WeightedMinFill`.\n", "3. Try a custom elimination order: the implemented heuristics for computing the elimination order might not be efficient in every case. If you can think of a more efficient order, you can pass it as a list of variables to the `elimination_order` argument.\n", "4. If it is still too slow, try approximate inference with sampling algorithms." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 1 }